#!/usr/bin/env python
"""GRR Colab module.

The module contains classes that Colab users will use to interact with GRR API.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import datetime
import io

from IPython.lib import pretty
from typing import Text, Sequence, List, Optional

from grr_api_client import client
from grr_api_client import errors as api_errors
from grr_colab import _api
from grr_colab import _timeout
from grr_colab import errors
from grr_colab import fs
from grr_colab import representer
from grr_colab import vfs
from grr_colab._textify import client as client_textify
from grr_response_proto import artifact_pb2
from grr_response_proto import flows_pb2
from grr_response_proto import jobs_pb2
from grr_response_proto import knowledge_base_pb2
from grr_response_proto import osquery_pb2
from grr_response_proto import sysinfo_pb2


def set_no_flow_timeout() -> None:
  """Removes the flow timeout so that flows are awaited indefinitely.

  Returns:
    Nothing.
  """
  # A timeout of None is the "wait forever" setting.
  unlimited = None
  _timeout.set_timeout(unlimited)


def set_default_flow_timeout() -> None:
  """Restores the flow timeout to its default (documented as 30 seconds).

  The default value itself is owned by the `_timeout` module.

  Returns:
    Nothing.
  """
  return _timeout.reset_timeout() and None


def set_flow_timeout(timeout: int) -> None:
  """Configures how long flows are awaited.

  Args:
    timeout: Number of seconds to wait for a flow; 0 means do not wait at all.

  Returns:
    Nothing.

  Raises:
    ValueError: If `timeout` is None or negative.
  """
  problem = None
  if timeout is None:
    problem = 'Timeout is not specified'
  elif timeout < 0:
    problem = 'Timeout cannot be negative'

  if problem is not None:
    raise ValueError(problem)

  _timeout.set_timeout(timeout)


def list_artifacts() -> Sequence[artifact_pb2.ArtifactDescriptor]:
  """Returns descriptors of every artifact registered on the GRR server.

  Returns:
    A list of artifact descriptors (the `data` of each listed artifact).
  """
  api = _api.get()
  descriptors = []
  for entry in api.ListArtifacts():
    descriptors.append(entry.data)
  return descriptors


class Client(object):
  """Wrapper for a GRR Client.

  Offers easy to use methods to interact with GRR API
  from Colab.

  Attributes:
    id: Id of the client.
    hostname: Hostname of the client.
    ifaces: A list of network interfaces of the given client.
    knowledgebase: Knowledgebase for the client.
    arch: Architecture that the client is running on.
    kernel: Kernel version string of the client.
    labels: A list of labels associated with the client.
    first_seen: Returns the time the client was seen for the first time.
    last_seen: Returns the time the client was seen for the last time.
    cached: A VFS instance that allows to work with filesystem data saved on the
      server that may not be up-to-date but is much faster to access.
    os: OS filesystem instance that encapsulates filesystem related operations.
    tsk: TSK filesystem instance that encapsulates filesystem related
      operations.
    registry: REGISTRY filesystem instance that encapsulates filesystem related
      operations.
  """

  def __init__(self, client_: client.Client) -> None:
    self._client = client_
    # Set by `interrogate()`. While it is None, `hostname`, `ifaces`, `arch` and
    # `kernel` fall back to the data of the wrapped API client object.
    self._summary = None  # type: Optional[jobs_pb2.ClientSummary]

  @classmethod
  def with_id(cls, client_id: Text) -> 'Client':
    """Fetches the client with the given id.

    Args:
      client_id: Id of the client to fetch.

    Returns:
      A wrapper around the fetched client.

    Raises:
      errors.UnknownClientError: If the API call fails with `UnknownError`.
    """
    try:
      return cls(_api.get().Client(client_id).Get())
    except api_errors.UnknownError as e:
      raise errors.UnknownClientError(client_id, e) from e

  @classmethod
  def with_hostname(cls, hostname: Text) -> 'Client':
    """Returns the only client that a hostname search matches.

    Args:
      hostname: Hostname to search for (passed as the `host` keyword).

    Returns:
      The single matching client.

    Raises:
      errors.UnknownHostnameError: If the search returns no clients.
      errors.AmbiguousHostnameError: If the search returns several clients.
    """
    clients = cls.search(host=hostname)
    if not clients:
      raise errors.UnknownHostnameError(hostname)
    if len(clients) > 1:
      raise errors.AmbiguousHostnameError(hostname, [c.id for c in clients])
    return clients[0]

  @classmethod
  def search(cls,
             ip: Optional[Text] = None,
             mac: Optional[Text] = None,
             host: Optional[Text] = None,
             version: Optional[int] = None,
             labels: Optional[List[Text]] = None,
             user: Optional[Text] = None) -> Sequence['Client']:
    """Searches for clients specified with keywords.

    All given criteria are joined with spaces into a single search query.

    Args:
      ip: Client IP address.
      mac: Client MAC address.
      host: Client hostname.
      version: Client version.
      labels: Client labels.
      user: Client username.

    Returns:
      A sequence of clients.
    """

    def format_keyword(key: Text, value: Text) -> Text:
      return '{}:{}'.format(key, value)

    keywords = []
    if ip is not None:
      keywords.append(format_keyword('ip', ip))
    if mac is not None:
      keywords.append(format_keyword('mac', mac))
    if host is not None:
      keywords.append(format_keyword('host', host))
    if version is not None:
      keywords.append(format_keyword('client', str(version)))
    if labels:
      for label in labels:
        keywords.append(format_keyword('label', label))
    if user is not None:
      keywords.append(format_keyword('user', user))

    query = ' '.join(keywords)
    found = _api.get().SearchClients(query)
    return representer.ClientList([cls(api_client) for api_client in found])

  @property
  def id(self) -> Text:
    return self._client.client_id

  @property
  def hostname(self) -> Text:
    if self._summary is not None:
      return self._summary.system_info.fqdn
    return self.knowledgebase.fqdn

  @property
  def ifaces(self) -> Sequence[jobs_pb2.Interface]:
    if self._summary is not None:
      return representer.InterfaceList(self._summary.interfaces)
    return representer.InterfaceList(self._client.data.interfaces)

  @property
  def knowledgebase(self) -> knowledge_base_pb2.KnowledgeBase:
    return self._client.data.knowledge_base

  @property
  def arch(self) -> Text:
    if self._summary is not None:
      return self._summary.system_info.machine
    return self._client.data.os_info.machine

  @property
  def kernel(self) -> Text:
    if self._summary is not None:
      return self._summary.system_info.kernel
    return self._client.data.os_info.kernel

  @property
  def labels(self) -> Sequence[Text]:
    return [label.name for label in self._client.data.labels]

  @property
  def first_seen(self) -> datetime.datetime:
    return _microseconds_to_datetime(self._client.data.first_seen_at)

  @property
  def last_seen(self) -> datetime.datetime:
    return _microseconds_to_datetime(self._client.data.last_seen_at)

  @property
  def os(self) -> fs.FileSystem:
    return fs.FileSystem(self._client, jobs_pb2.PathSpec.OS)

  @property
  def tsk(self) -> fs.FileSystem:
    return fs.FileSystem(self._client, jobs_pb2.PathSpec.TSK)

  @property
  def registry(self) -> fs.FileSystem:
    return fs.FileSystem(self._client, jobs_pb2.PathSpec.REGISTRY)

  @property
  def cached(self) -> vfs.VFS:
    return self.os.cached

  def _create_approval(self, approvers: List[Text], reason: Text):
    """Validates an approval request and creates it via the API.

    Args:
      approvers: List of users who will be notified of this request.
      reason: Reason for this approval.

    Returns:
      The approval object returned by the API client's `CreateApproval`.

    Raises:
      ValueError: If `reason` or `approvers` is empty.
    """
    if not reason:
      raise ValueError('Approval reason is not provided')
    if not approvers:
      raise ValueError('List of approvers is empty')
    return self._client.CreateApproval(reason=reason, notified_users=approvers)

  def request_approval(self, approvers: List[Text], reason: Text) -> None:
    """Sends approval request to the client for the current user.

    Args:
      approvers: List of users who will be notified of this request.
      reason: Reason for this approval.

    Returns:
      Nothing.

    Raises:
      ValueError: If `reason` or `approvers` is empty.
    """
    self._create_approval(approvers, reason)

  def request_approval_and_wait(self, approvers: List[Text],
                                reason: Text) -> None:
    """Sends approval request and waits until it's granted.

    Args:
      approvers: List of users who will be notified of this request.
      reason: Reason for this approval.

    Returns:
      Nothing.

    Raises:
      ValueError: If `reason` or `approvers` is empty.
    """
    self._create_approval(approvers, reason).WaitUntilValid()

  def _run_flow(self, name: Text, args=None):
    """Starts a flow on the client and waits for it using `_timeout`.

    Args:
      name: Name of the flow to start.
      args: Optional flow arguments proto. If None, no `args` is passed to the
        API call.

    Returns:
      The flow object, after `_timeout.await_flow` has returned for it.

    Raises:
      errors.ApprovalMissingError: If the API forbids starting the flow.
    """
    try:
      if args is None:
        flow = self._client.CreateFlow(name=name)
      else:
        flow = self._client.CreateFlow(name=name, args=args)
    except api_errors.AccessForbiddenError as e:
      raise errors.ApprovalMissingError(self.id, e) from e

    _timeout.await_flow(flow)
    return flow

  def interrogate(self) -> jobs_pb2.ClientSummary:
    """Grabs fresh metadata about the client.

    The result is cached and takes precedence over the API client data in
    `hostname`, `ifaces`, `arch` and `kernel`.

    Returns:
      A client summary.
    """
    flow = self._run_flow('Interrogate')
    # Only the first result is used; an empty result set raises IndexError.
    self._summary = list(flow.ListResults())[0].payload
    return self._summary

  def ps(self) -> Sequence[sysinfo_pb2.Process]:
    """Returns a list of processes running on the client."""
    flow = self._run_flow('ListProcesses', flows_pb2.ListProcessesArgs())
    return representer.ProcessList(
        [result.payload for result in flow.ListResults()])

  def ls(self, path: Text, max_depth: int = 1) -> Sequence[jobs_pb2.StatEntry]:
    """Lists contents of a given directory.

    Args:
      path: A path to the directory to list the contents of.
      max_depth: Max depth of subdirectories to explore. If max_depth is >1,
        then the results will also include the contents of subdirectories (and
        sub-subdirectories and so on).

    Returns:
      A sequence of stat entries.
    """
    return self.os.ls(path, max_depth)

  def glob(self, path: Text) -> Sequence[jobs_pb2.StatEntry]:
    """Globs for files on the given client.

    Args:
      path: A glob expression (that may include `*` and `**`).

    Returns:
      A sequence of stat entries to the found files.
    """
    return self.os.glob(path)

  def grep(self, path: Text,
           pattern: bytes) -> Sequence[jobs_pb2.BufferReference]:
    """Searches the specified path for content matching a regular expression.

    Args:
      path: A path to a file to be searched.
      pattern: A regular expression to search for.

    Returns:
      A list of buffer references to the matched content.
    """
    return self.os.grep(path, pattern)

  def fgrep(self, path: Text,
            literal: bytes) -> Sequence[jobs_pb2.BufferReference]:
    """Searches the specified path for a literal byte string.

    Args:
      path: A path to a file to be searched.
      literal: A literal expression to search for.

    Returns:
      A list of buffer references to the matched content.
    """
    return self.os.fgrep(path, literal)

  def osquery(self,
              query: Text,
              timeout: int = 30000,
              ignore_stderr_errors: bool = False) -> osquery_pb2.OsqueryTable:
    """Runs given query on the client.

    Args:
      query: An SQL query to run against osquery on the client.
      timeout: Query timeout in millis.
      ignore_stderr_errors: If true, will not break in case of stderr errors.

    Returns:
      An osquery table corresponding to the result of running the query.
    """
    args = osquery_pb2.OsqueryArgs()
    args.query = query
    args.timeout_millis = timeout
    args.ignore_stderr_errors = ignore_stderr_errors

    flow = self._run_flow('OsqueryFlow', args)
    # Only the first result is used; an empty result set raises IndexError.
    return list(flow.ListResults())[0].payload.table

  def collect(self,
              artifact: Text) -> Sequence[artifact_pb2.ClientActionResult]:
    """Collects specified artifact.

    Args:
      artifact: A name of the artifact to collect.

    Returns:
      A list of results that artifact collection yielded.
    """
    args = flows_pb2.ArtifactCollectorFlowArgs()
    args.artifact_list.append(artifact)
    args.apply_parsers = True

    flow = self._run_flow('ArtifactCollectorFlow', args)
    return [result.payload for result in flow.ListResults()]

  def yara(
      self,
      signature: Text,
      pids: Optional[Sequence[int]] = None,
      regex: Optional[Text] = None
  ) -> Sequence[flows_pb2.YaraProcessScanResponse]:
    """Scans processes using provided YARA rule.

    Args:
      signature: YARA rule to run.
      pids: List of pids of processes to scan.
      regex: A regex to match against the process name.

    Returns:
      A list of YARA matches.
    """
    args = flows_pb2.YaraProcessScanRequest()
    args.yara_signature = signature
    args.ignore_grr_process = False

    if regex is not None:
      args.process_regex = regex
    if pids:
      args.pids.extend(pids)

    flow = self._run_flow('YaraProcessScan', args)
    return [result.payload for result in flow.ListResults()]

  def wget(self, path: Text) -> Text:
    """Downloads a file and returns a link to it.

    Args:
      path: A path to download.

    Returns:
      A link to the file.
    """
    return self.os.wget(path)

  def open(self, path: Text) -> io.BufferedIOBase:
    """Opens a file object corresponding to the given path on the client.

    The returned file object is read-only.

    Args:
      path: A path to the file to open.

    Returns:
      A file-like object (implementing standard IO interface).
    """
    return self.os.open(path)

  def _repr_pretty_(self, p: pretty.PrettyPrinter, cycle: bool) -> None:
    """IPython pretty-printer hook: `<icon> <id> @ <host> (<last seen>)`."""
    del cycle  # Unused.
    icon = client_textify.online_icon(self._client.data.last_seen_at)
    last_seen = client_textify.last_seen(self._client.data.last_seen_at)
    data = '{icon} {id} @ {host} ({last_seen})'.format(
        icon=icon, id=self.id, last_seen=last_seen, host=self.hostname)
    p.text(data)


def _microseconds_to_datetime(ms: int) -> datetime.datetime:
  return datetime.datetime.utcfromtimestamp(ms / (10**6))
